package com.zhanghe.study.thread.pool.forkjoin;

import java.util.concurrent.ForkJoinPool;
import java.util.concurrent.ForkJoinTask;
import java.util.concurrent.RecursiveTask;

/**
 * Demonstrates {@link RecursiveTask}, the fork/join task type that produces a result.
 *
 * <p>The example sums every integer in the inclusive range [10, 20] by recursively splitting the
 * range in half until each piece is small enough to add up directly.
 *
 * @author zh
 * @date 2022/1/10 10:32
 */
public class TestRecursiveTask {

    /**
     * Submits a {@link MyRecursiveTask} for the range [10, 20] to a dedicated pool and prints the
     * resulting sum.
     *
     * @param args unused
     */
    public static void main(String[] args) {
        ForkJoinPool pool = new ForkJoinPool();
        try {
            MyRecursiveTask task = new MyRecursiveTask(10, 20);
            // Typed handle (not a raw ForkJoinTask): join() yields the Integer from compute().
            ForkJoinTask<Integer> result = pool.submit(task);
            System.out.println(result.join());
        } finally {
            // Release the pool's worker threads once the demo is finished.
            pool.shutdown();
        }
    }

    /**
     * Sums the integers in the inclusive range [begin, end].
     *
     * <p>A range whose {@code end - begin} exceeds {@link #THRESHOLD} is split at its midpoint into
     * two subtasks that run via {@code invokeAll}. Smaller ranges are summed directly, and the
     * partial sum is printed. An empty range ({@code begin > end}) sums to 0. The running sum is an
     * {@code int} and wraps on overflow.
     */
    static class MyRecursiveTask extends RecursiveTask<Integer> {

        /** Largest value of {@code end - begin} that is summed directly instead of being split. */
        private static final long THRESHOLD = 2;

        private final int begin;
        private final int end;

        /**
         * @param begin first value of the range (inclusive)
         * @param end   last value of the range (inclusive)
         */
        public MyRecursiveTask(int begin, int end) {
            this.begin = begin;
            this.end = end;
        }

        @Override
        protected Integer compute() {
            // Widen to long: end - begin overflows int when the range spans more than
            // Integer.MAX_VALUE values.
            if ((long) end - begin > THRESHOLD) {
                // The long sum cannot overflow. Division truncates toward zero exactly like the
                // original int expression. Because end - begin > 2 here, begin <= mid < end,
                // so both halves are non-empty and mid + 1 cannot overflow.
                int mid = (int) (((long) begin + end) / 2);
                MyRecursiveTask leftTask = new MyRecursiveTask(begin, mid);
                MyRecursiveTask rightTask = new MyRecursiveTask(mid + 1, end);
                invokeAll(leftTask, rightTask);
                return leftTask.join() + rightTask.join();
            }
            int result = 0;
            // A long counter is required: an int counter would wrap and loop forever when
            // end == Integer.MAX_VALUE.
            for (long i = begin; i <= end; i++) {
                result += (int) i;
            }
            System.out.println(begin + "~" + end + "累加结果为" + result);
            return result;
        }
    }
}
